import pandas as pd
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.patches import FancyArrowPatch

def process_features_and_trips(df_trips_inner, df_features):
    """Attach inner-trip ids to the trip table and build a per-inner-trip feature table.

    An "inner trip" is one (driver, trip_id_new, trip_id_cut) group of
    ``df_trips_inner``; each such group gets a dense integer ``id_inner_trip``.

    Args:
        df_trips_inner: Node-level trip rows with the columns ``driver``,
            ``trip_id``, ``trip_id_new``, ``trip_id_cut``, ``cluster`` and
            ``node_id``.  The caller's frame is left unmodified.
        df_features: Trip-level features keyed by ``driver``/``trip_id`` with
            the columns ``day_of_Week``, ``is_Holiday`` and ``time_start``.

    Returns:
        ``(df_trips_inner, df_features)``:
          * the trip table reduced to ``id_inner_trip, cluster, node_id, step``,
            where ``step`` is the 0-based position of the row within its
            (driver, trip_id_new, trip_id_cut, cluster) group;
          * the feature table with one row per ``id_inner_trip``, one-hot
            ``day_of_Week_*`` columns and ``time_start`` min-max scaled to
            [0, 1] (all zeros when every start time is identical).
    """
    # Work on a copy so the helper columns below do not leak into the caller's frame.
    df_trips_inner = df_trips_inner.copy()
    df_trips_inner['step'] = df_trips_inner.groupby(
        ['driver', 'trip_id_new', 'trip_id_cut', 'cluster']).node_id.transform('cumcount')
    df_trips_inner['id_inner_trip'] = df_trips_inner.groupby(
        ['driver', 'trip_id_new', 'trip_id_cut']).ngroup()

    inner_trip_keys = df_trips_inner[
        ['id_inner_trip', 'driver', 'trip_id', 'trip_id_new', 'trip_id_cut', 'cluster']
    ].drop_duplicates()
    df_features = df_features.merge(inner_trip_keys, on=['driver', 'trip_id'], how='inner')

    df_trips_inner = df_trips_inner[['id_inner_trip', 'cluster', 'node_id', 'step']]
    df_features = df_features[['id_inner_trip', 'cluster', 'day_of_Week', 'is_Holiday', 'time_start']]
    df_features = pd.get_dummies(df_features, columns=['day_of_Week'])

    # Min-max scale; a zero range would otherwise yield 0/0 = NaN for every row.
    t_start = df_features['time_start']
    t_min, t_max = t_start.min(), t_start.max()
    if t_max == t_min:
        df_features['time_start'] = 0.0
    else:
        df_features['time_start'] = (t_start - t_min) / (t_max - t_min)

    df_features = df_features.drop_duplicates(subset=['id_inner_trip'], keep='first')
    return df_trips_inner, df_features

def sort_nodes_in_cl(df_nodes, clusters):
    """Select the nodes of ``clusters`` and give them a contiguous sorted rank.

    Nodes are ordered first by the position of their cluster in ``clusters``
    and then by ``node_id``; the 0-based rank is stored in ``node_sorted``.

    Args:
        df_nodes: Node table with at least ``node_id`` and ``cluster`` columns.
        clusters: Sequence of cluster labels; its order defines the cluster order.

    Returns:
        ``(df_nodes_in_cluster, map_nodes)``: the filtered and sorted frame
        (with added ``cluster_sorted`` and ``node_sorted`` columns) and a dict
        mapping each ``node_id`` to its integer ``node_sorted`` rank.
    """
    cluster_rank = {cluster: rank for rank, cluster in enumerate(clusters)}
    df_nodes_in_cluster = df_nodes[df_nodes['cluster'].isin(clusters)].reset_index(drop=True)
    df_nodes_in_cluster['cluster_sorted'] = df_nodes_in_cluster['cluster'].map(cluster_rank)
    df_nodes_in_cluster = df_nodes_in_cluster.sort_values(
        by=['cluster_sorted', 'node_id']).reset_index(drop=True)
    df_nodes_in_cluster['node_sorted'] = np.arange(len(df_nodes_in_cluster))
    # Zip the two columns rather than going through a 2-D np.array, which would
    # upcast the ranks to node_id's dtype (e.g. float ids -> float ranks).
    map_nodes = dict(zip(df_nodes_in_cluster['node_id'], df_nodes_in_cluster['node_sorted']))
    return df_nodes_in_cluster, map_nodes

def haversine(lat1, lon1, lat2, lon2, radius=6371.0):
    """Great-circle distance between points given in decimal degrees.

    All arithmetic goes through NumPy ufuncs, so the coordinates may be
    scalars or NumPy arrays (broadcast element-wise).

    Args:
        lat1, lon1: Latitude / longitude of the first point(s), in degrees.
        lat2, lon2: Latitude / longitude of the second point(s), in degrees.
        radius: Sphere radius; the result is expressed in the same unit.
            Defaults to 6371.0, the Earth radius in kilometres.

    Returns:
        Distance (scalar or array) in the unit of ``radius``.
    """
    dlat = np.radians(lat2 - lat1)
    dlon = np.radians(lon2 - lon1)
    a = np.sin(dlat / 2) ** 2 + np.cos(np.radians(lat1)) * np.cos(np.radians(lat2)) * np.sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` a hair outside [0, 1] (e.g. for
    # near-antipodal points), which would turn sqrt(1 - a) into NaN.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    return radius * c

def create_prior_distance_matrix(nodes_df, df_edges, fill_value=5000.):
    """Build the directed 0/1 adjacency matrix and the matching edge-length matrix.

    Args:
        nodes_df: Node table with ``node_sorted``, ``node_lat`` and ``node_lon``
            columns.  After sorting by ``node_sorted``, row ``k`` is treated as
            node ``k`` -- this assumes ``node_sorted`` is exactly ``0..n-1``.
        df_edges: Edge table whose ``node_from`` / ``node_to`` columns hold
            row indices into the sorted node table.
        fill_value: Value stored for node pairs without an edge
            (default 5000.0).

    Returns:
        ``(adj_matrix_sorted, distance_matrix)``: an ``(n, n)`` int matrix with
        1 at every ``[node_from, node_to]``, and an ``(n, n)`` float matrix with
        the haversine length of each such edge and ``fill_value`` elsewhere.
    """
    nodes_df = nodes_df.sort_values(by='node_sorted').reset_index(drop=True)
    n = len(nodes_df)

    # Explicit int conversion: iterrows() would upcast the indices to float
    # whenever df_edges has a float column, and numpy rejects float indices.
    src = df_edges['node_from'].to_numpy(dtype=int)
    dst = df_edges['node_to'].to_numpy(dtype=int)
    adj_matrix_sorted = np.zeros((n, n), dtype=int)
    adj_matrix_sorted[src, dst] = 1

    # Only existing edges need a length: evaluate haversine on all of them at
    # once instead of scanning every (i, j) pair with per-cell .loc lookups.
    distance_matrix = np.full((n, n), fill_value, dtype=float)
    rows, cols = np.nonzero(adj_matrix_sorted)
    lat = nodes_df['node_lat'].to_numpy(dtype=float)
    lon = nodes_df['node_lon'].to_numpy(dtype=float)
    distance_matrix[rows, cols] = haversine(lat[rows], lon[rows], lat[cols], lon[cols])

    return adj_matrix_sorted, distance_matrix

def get_prior_and_M_indices(nodes, edges):
    """Return the prior distance matrix, per-edge priors and the edge index list.

    Returns:
        prior_M_in_cluster: ``float32`` tensor copy of the ``(n, n)`` distance
            matrix from ``create_prior_distance_matrix``.
        edges_prior_in_cluster: 1-D tensor with the prior of each edge, in the
            same order as ``M_indices``.
        M_indices: ``long`` tensor of shape ``(E, 2)`` with the ``(from, to)``
            position of every 1 in the adjacency matrix, in row-major order.
    """
    adjacency, prior_np = create_prior_distance_matrix(nodes, edges)
    prior_M_in_cluster = torch.tensor(prior_np, dtype=torch.float32)
    # argwhere lists the coordinates of each 1 in row-major order, the same
    # order np.where produces.
    M_indices = torch.tensor(np.argwhere(adjacency == 1), dtype=torch.long)
    from_idx, to_idx = M_indices[:, 0], M_indices[:, 1]
    edges_prior_in_cluster = prior_M_in_cluster[from_idx, to_idx]
    return prior_M_in_cluster, edges_prior_in_cluster, M_indices





    


def get_m_combined(node_idx_sequence_trip, map_bridges, nodes_in_cluster_sorted,
                   nodes_inner_A, nodes_inner_B, nodes_bridge_A, nodes_bridge_B, nodes_bridge,
                   V, Vk):
    """One-hot encode every sub-path of a trip as ``(start, end, via)`` -- bridge-set variant.

    NOTE: ``get_m_combined`` is defined again later in this module, so this
    definition is shadowed and not reachable by name at runtime.

    For every pair of positions ``i < j`` in the trip, the "via" key is:
      * the largest bridge node strictly between positions ``i`` and ``j``,
        if there is one;
      * otherwise the start node itself, when ``j == i + 1`` and the start
        node is a bridge;
      * otherwise ``max(nodes_in_cluster_sorted) + 1`` (sub-path treated as
        staying inside the cluster).
    The slot on the last axis is ``map_bridges[via]``.

    Args:
        node_idx_sequence_trip: Ordered node indices of the trip.
        map_bridges: Mapping from via key to an index in ``range(Vk)``.
        nodes_in_cluster_sorted: Node indices of the cluster; only the maximum is used.
        nodes_inner_A, nodes_inner_B, nodes_bridge_A, nodes_bridge_B: Not used
            by this variant; kept for signature compatibility.
        nodes_bridge: Collection of bridge node indices.
        V, Vk: Sizes of the two node axes and of the via axis.

    Returns:
        Float array of shape ``(V, V, Vk)`` with 1 at
        ``[start, end, map_bridges[via]]`` for every sub-path, 0 elsewhere.

    Raises:
        KeyError: if a via key is missing from ``map_bridges``.
    """
    m_combined = np.zeros((V, V, Vk))
    seq = node_idx_sequence_trip
    if len(seq) < 2:
        return m_combined

    inside_key = max(nodes_in_cluster_sorted) + 1

    for i in range(len(seq) - 1):
        for j in range(i + 1, len(seq)):
            borders = set(seq[i + 1:j]).intersection(nodes_bridge)
            if borders:
                via = max(borders)
            elif j - i == 1 and seq[i] in nodes_bridge:
                # Single hop that leaves from a bridge node.
                via = seq[i]
            else:
                via = inside_key
            m_combined[seq[i], seq[j], map_bridges[via]] = 1.

    return m_combined



def get_m_combined(node_idx_sequence_trip, map_bridges, nodes_in_cluster_sorted,
                   nodes_inner_A, nodes_inner_B, nodes_bridge_A, nodes_bridge_B, nodes_bridge,
                   V, Vk):
    """One-hot encode every sub-path of a trip as ``(start, end, via)`` -- two-side variant.

    NOTE: ``get_m_combined`` is defined again later in this module, so this
    definition is shadowed and not reachable by name at runtime.

    A sub-path covering positions ``i..j`` "stays inside" when all of its nodes
    lie in side A (``nodes_inner_A | nodes_bridge_A``) or all lie in side B
    (``nodes_inner_B | nodes_bridge_B``); it then uses the via key
    ``max(nodes_in_cluster_sorted) + 1``.  A crossing sub-path uses its start
    node when ``j == i + 1``, otherwise the largest bridge node strictly
    between positions ``i`` and ``j``.  The slot is ``map_bridges[via]``.

    Args:
        node_idx_sequence_trip: Ordered node indices of the trip.
        map_bridges: Mapping from via key to an index in ``range(Vk)``.
        nodes_in_cluster_sorted: Node indices of the cluster; only the maximum is used.
        nodes_inner_A, nodes_inner_B: Inner nodes of sides A and B.
        nodes_bridge_A, nodes_bridge_B: Bridge nodes attached to sides A and B.
        nodes_bridge: Collection of all bridge node indices.
        V, Vk: Sizes of the two node axes and of the via axis.

    Returns:
        Float array of shape ``(V, V, Vk)``.

    Raises:
        ValueError: if a crossing sub-path of two or more hops has no bridge
            node strictly inside it (previously this dropped into ``pdb``).
        KeyError: if a via key is missing from ``map_bridges``.
    """
    m_combined = np.zeros((V, V, Vk))
    seq = node_idx_sequence_trip
    if len(seq) < 2:
        return m_combined

    # Loop-invariant: build the two side sets and the "inside" key once.
    side_a = set(nodes_inner_A).union(nodes_bridge_A)
    side_b = set(nodes_inner_B).union(nodes_bridge_B)
    inside_key = max(nodes_in_cluster_sorted) + 1

    for i in range(len(seq) - 1):
        for j in range(i + 1, len(seq)):
            segment = set(seq[i:j + 1])
            if segment <= side_a or segment <= side_b:
                via = inside_key
            elif j - i == 1:
                via = seq[i]
            else:
                borders = set(seq[i + 1:j]).intersection(nodes_bridge)
                if not borders:
                    raise ValueError(
                        f"sub-path {seq[i]} -> {seq[j]} crosses sides but has no "
                        f"bridge node between its endpoints")
                via = max(borders)
            m_combined[seq[i], seq[j], map_bridges[via]] = 1.

    return m_combined


def get_m_combined(node_idx_sequence_trip, map_bridges, nodes_in_cluster_sorted,
                   dict_node_to_cluster, nodes_bridge,
                   V, Vk):
    """One-hot encode every sub-path of a trip as ``(start, end, via)``.

    This definition replaces the two earlier ``get_m_combined`` definitions
    in this module.

    For every pair of positions ``i < j`` in the trip, the "via" key is:
      * ``max(nodes_in_cluster_sorted) + 1`` when every node in positions
        ``i..j`` maps to the same cluster in ``dict_node_to_cluster``;
      * otherwise the start node, when ``j == i + 1``;
      * otherwise the largest bridge node strictly between ``i`` and ``j``.
    The slot on the last axis is ``map_bridges[via]``.

    Args:
        node_idx_sequence_trip: Ordered node indices of the trip.
        map_bridges: Mapping from via key to an index in ``range(Vk)``.
        nodes_in_cluster_sorted: Node indices of the cluster; only the maximum is used.
        dict_node_to_cluster: Mapping from node index to cluster label.
        nodes_bridge: Collection of bridge node indices.
        V, Vk: Sizes of the two node axes and of the via axis.

    Returns:
        Float array of shape ``(V, V, Vk)`` with 1 at
        ``[start, end, map_bridges[via]]`` for every sub-path, 0 elsewhere.

    Raises:
        KeyError: if a trip node is missing from ``dict_node_to_cluster`` or a
            via key is missing from ``map_bridges``.
        ValueError: if a sub-path of two or more hops spans several clusters
            without any bridge node strictly inside it.
    """
    m_combined = np.zeros((V, V, Vk))
    seq = node_idx_sequence_trip
    n_steps = len(seq)
    if n_steps < 2:
        return m_combined

    inside_key = max(nodes_in_cluster_sorted) + 1
    # Resolve each node's cluster once instead of once per (i, j) slice.
    clusters = [dict_node_to_cluster[node] for node in seq]

    for i in range(n_steps - 1):
        start = seq[i]
        for j in range(i + 1, n_steps):
            end = seq[j]
            if len(set(clusters[i:j + 1])) == 1:
                via = inside_key
            elif j - i == 1:
                via = start
            else:
                borders = set(seq[i + 1:j]).intersection(nodes_bridge)
                if not borders:
                    raise ValueError(
                        f"sub-path {start} -> {end} crosses clusters but has no "
                        f"bridge node between its endpoints")
                via = max(borders)
            try:
                slot = map_bridges[via]
            except KeyError as err:
                raise KeyError(
                    f"via key {via!r} of sub-path {start} -> {end} is missing "
                    f"from map_bridges") from err
            m_combined[start, end, slot] = 1.

    return m_combined


def get_m_combined_batch(node_idx_sequence_trips, map_bridges, nodes_in_cluster_sorted,
                         nodes_inner_A, nodes_inner_B, nodes_bridge_A, nodes_bridge_B, nodes_bridge,
                         idcs_batch, V, Vk):
    """Stack ``get_m_combined`` encodings for a 2-D grid of trip ids (A/B-set variant).

    NOTE(review): ``get_m_combined_batch`` is defined again right after this
    one, so this definition is shadowed.  It also passes ten arguments to
    ``get_m_combined``, while the final ``get_m_combined`` in this module
    takes seven, so calling this body would raise ``TypeError``.

    Args:
        node_idx_sequence_trips: Indexable collection of trip node sequences.
        map_bridges, nodes_in_cluster_sorted, nodes_inner_A, nodes_inner_B,
        nodes_bridge_A, nodes_bridge_B, nodes_bridge, V, Vk: Forwarded to
            ``get_m_combined`` for every trip.
        idcs_batch: Array whose first two axes form the batch grid; each entry
            is a key into ``node_idx_sequence_trips``.

    Returns:
        Float array of shape ``(B1, B2, V, V, Vk)`` with ``B1, B2`` the first
        two dimensions of ``idcs_batch``.
    """
    n_outer, n_inner = idcs_batch.shape[0], idcs_batch.shape[1]
    stacked = np.zeros((n_outer, n_inner, V, V, Vk))
    for outer, inner in np.ndindex(n_outer, n_inner):
        trip = node_idx_sequence_trips[idcs_batch[outer, inner]]
        stacked[outer, inner] = get_m_combined(
            trip, map_bridges, nodes_in_cluster_sorted,
            nodes_inner_A, nodes_inner_B, nodes_bridge_A, nodes_bridge_B, nodes_bridge,
            V, Vk)
    return stacked


def get_m_combined_batch(node_idx_sequence_trips, map_bridges, nodes_in_cluster_sorted, dict_node_to_cluster,
                         nodes_bridge, idcs_batch, V, Vk):
    """Stack ``get_m_combined`` encodings for a 2-D grid of trip ids.

    Args:
        node_idx_sequence_trips: Indexable collection of trip node sequences.
        map_bridges, nodes_in_cluster_sorted, dict_node_to_cluster,
        nodes_bridge, V, Vk: Forwarded to ``get_m_combined`` for every trip.
        idcs_batch: Array whose first two axes form the batch grid; each entry
            is a key into ``node_idx_sequence_trips``.

    Returns:
        Float array of shape ``(B1, B2, V, V, Vk)`` with ``B1, B2`` the first
        two dimensions of ``idcs_batch``.
    """
    n_outer, n_inner = idcs_batch.shape[0], idcs_batch.shape[1]
    stacked = np.zeros((n_outer, n_inner, V, V, Vk))
    for outer, inner in np.ndindex(n_outer, n_inner):
        trip = node_idx_sequence_trips[idcs_batch[outer, inner]]
        stacked[outer, inner] = get_m_combined(
            trip, map_bridges, nodes_in_cluster_sorted, dict_node_to_cluster,
            nodes_bridge, V, Vk)
    return stacked



def get_m_inter(node_idx_sequence_trip, V, Vk):
    """One-hot encode every sub-path of a trip by its highest intermediate node.

    For each pair of positions ``i < j``, the via index is the start node
    itself when the positions are adjacent (``j == i + 1``), and otherwise the
    largest node strictly between them.  ``m_inter[start, end, via]`` is set to 1.

    Returns:
        Float array of shape ``(V, V, Vk)``.
    """
    m_inter = np.zeros((V, V, Vk))
    seq = node_idx_sequence_trip
    length = len(seq)
    for left in range(length - 1):
        start = seq[left]
        # Adjacent pair: the start node doubles as the via index.
        m_inter[start, seq[left + 1], start] = 1.
        # Longer sub-paths: keep a running maximum of the interior nodes.
        highest = None
        for right in range(left + 2, length):
            candidate = seq[right - 1]
            highest = candidate if highest is None else max(highest, candidate)
            m_inter[start, seq[right], highest] = 1.
    return m_inter


def get_m_inter_batch(node_idx_sequence_trips, idcs_batch, V, Vk):
    """Stack ``get_m_inter`` encodings for a 2-D grid of trip ids.

    Args:
        node_idx_sequence_trips: Indexable collection of trip node sequences.
        idcs_batch: Array whose first two axes form the batch grid; each entry
            is a key into ``node_idx_sequence_trips``.
        V, Vk: Forwarded to ``get_m_inter``.

    Returns:
        Float array of shape ``(B1, B2, V, V, Vk)`` with ``B1, B2`` the first
        two dimensions of ``idcs_batch``.
    """
    n_outer, n_inner = idcs_batch.shape[0], idcs_batch.shape[1]
    stacked = np.zeros((n_outer, n_inner, V, V, Vk))
    for outer, inner in np.ndindex(n_outer, n_inner):
        trip = node_idx_sequence_trips[idcs_batch[outer, inner]]
        stacked[outer, inner] = get_m_inter(trip, V, Vk)
    return stacked







def interm_node(probs, sn, en, probabilistic):
    """Pick an intermediate node for the pair ``(sn, en)`` from ``probs[sn, en]``.

    With ``probabilistic`` the scores are normalised to sum to one and an index
    is drawn with ``np.random.choice`` (global NumPy RNG); otherwise the index
    of the highest score is returned.
    """
    scores = probs[sn, en]
    if not probabilistic:
        return np.argmax(scores)
    weights = scores / scores.sum()
    candidates = np.arange(len(weights))
    return np.random.choice(candidates, p=weights)

def random_path(probs, sn, en, visited_nodes, probabilistic):
    """Expand ``sn -> en`` into a node path by repeatedly inserting intermediate nodes.

    For a segment ``(a, b)`` an intermediate node ``m = interm_node(probs, a, b,
    probabilistic)`` is requested.  If ``m`` equals ``a`` or ``b`` or is already
    visited, the segment stays a direct hop.  Otherwise ``m`` is marked visited
    and the halves ``(a, m)`` and ``(m, b)`` are expanded, left half first.

    An explicit stack replaces recursion, so long paths cannot hit Python's
    recursion limit.  Segments are expanded in the same order as the recursive
    form, so the sequence of ``interm_node`` calls is unchanged.  This includes
    the random draws in probabilistic mode.

    Args:
        probs: Forwarded to ``interm_node``.
        sn, en: Start and end node.
        visited_nodes: Set of nodes that may not be inserted; ``None`` means
            ``{sn, en}``.  A set passed in is updated in place.
        probabilistic: Forwarded to ``interm_node``.

    Returns:
        List of nodes from ``sn`` to ``en``.
    """
    if visited_nodes is None:
        visited_nodes = {sn, en}

    path = [sn]
    # Segments still to expand; the top of the stack is the left-most one.
    pending = [(sn, en)]
    while pending:
        seg_start, seg_end = pending.pop()
        internode = interm_node(probs, seg_start, seg_end, probabilistic)
        if internode == seg_start or internode == seg_end or internode in visited_nodes:
            # Direct hop: its start is already the last node on the path.
            path.append(seg_end)
            continue
        visited_nodes.add(internode)
        # Push the right half first so the left half is expanded before it.
        pending.append((internode, seg_end))
        pending.append((seg_start, internode))
    return path

def get_path_given_probs(probs_sample, sn, en, probabilistic):
    """Build a path ``sn -> en`` with ``random_path`` and return it as an edge list.

    Returns:
        Array of shape ``(len(path) - 1, 2)`` whose rows are consecutive
        ``(from, to)`` node pairs along the path.
    """
    path = random_path(probs_sample, sn, en, None, probabilistic)
    sources, targets = path[:-1], path[1:]
    return np.column_stack([sources, targets])

def find_next_edge(current_node, edges):
    """Return the first edge whose source (``edge[0]``) equals ``current_node``, or ``None``."""
    return next((edge for edge in edges if edge[0] == current_node), None)

def sort_nodes_path(start_node, end_node, edges):
    """Order an unordered edge list into a node path from ``start_node`` to ``end_node``.

    Starting at ``start_node``, the walk repeatedly follows the first edge
    (per ``find_next_edge``) whose source is the current node.

    Args:
        start_node, end_node: Node indices.
        edges: Sized, re-iterable collection of ``(from, to)`` pairs
            (e.g. a list or a 2-D array).

    Returns:
        List of nodes beginning with ``start_node``.  When ``end_node`` is not
        reached (dead end, or a cycle that never visits it), prints
        ``"No path found!"`` and returns the partial path.
    """
    path_nodes = [start_node]
    current_node = start_node
    # Each node has a single "first" outgoing edge, so the walk uses distinct
    # edges until it repeats a node.  After len(edges) steps without reaching
    # end_node it must be stuck in a cycle.  The old unbounded loop never
    # terminated in that case.
    for _ in range(len(edges)):
        if current_node == end_node:
            return path_nodes
        next_edge = find_next_edge(current_node, edges)
        if next_edge is None:
            break
        path_nodes.append(int(next_edge[1]))
        current_node = next_edge[1]
    if current_node != end_node:
        print("No path found!")
    return path_nodes

def get_lon_lat_from_nodes(path_nodes, df_nodes_in_cluster):
    """Look up coordinates for a sequence of ``node_sorted`` ids.

    Returns:
        ``(longitudes, latitudes)`` as NumPy arrays in the order of
        ``path_nodes`` (repeated ids yield repeated coordinates).
    """
    by_node = df_nodes_in_cluster.set_index('node_sorted')
    rows = by_node.loc[path_nodes]
    return rows['node_lon'].to_numpy(), rows['node_lat'].to_numpy()

def get_lat_lon_comparison(probs, sample, sn, en, edges_seq_full, df_nodes_in_cluster):
    """Coordinates of the greedily decoded and the true path of one sample.

    The predicted edge list is decoded deterministically from
    ``probs[sample]`` via ``get_path_given_probs``; the true one is
    ``edges_seq_full[sample]``.  Both are ordered into node paths from ``sn``
    to ``en`` before their coordinates are looked up.

    Returns:
        ``(lon_pred, lat_pred, lon_true, lat_true)``.
    """
    predicted_edges = get_path_given_probs(probs[sample], sn, en, False)
    true_edges = edges_seq_full[sample]

    pred_nodes = sort_nodes_path(sn, en, predicted_edges)
    true_nodes = sort_nodes_path(sn, en, true_edges)

    lon_pred, lat_pred = get_lon_lat_from_nodes(pred_nodes, df_nodes_in_cluster)
    lon_true, lat_true = get_lon_lat_from_nodes(true_nodes, df_nodes_in_cluster)
    return lon_pred, lat_pred, lon_true, lat_true


def get_lat_lon_comparison_2(probs, sample, sn, en, edges_seq_full, df_nodes_in_cluster):
    """Coordinates of a given predicted path and the true path of one sample.

    Unlike ``get_lat_lon_comparison``, ``probs[sample]`` is used directly as
    the predicted edge list (no decoding step).

    Returns:
        ``(lon_pred, lat_pred, lon_true, lat_true)``.
    """
    edge_lists = (probs[sample], edges_seq_full[sample])
    node_paths = [sort_nodes_path(sn, en, edge_list) for edge_list in edge_lists]
    (lon_pred, lat_pred), (lon_true, lat_true) = [
        get_lon_lat_from_nodes(node_path, df_nodes_in_cluster) for node_path in node_paths
    ]
    return lon_pred, lat_pred, lon_true, lat_true


def plot_arrow(ax, lon, lat, alpha, color, width, shrink=0.85):
    """Draw one arrow per consecutive pair of path points on ``ax``.

    Args:
        ax: Axes-like object; only its ``arrow`` method is used.
        lon, lat: Longitudes / latitudes of the path points, in order.
        alpha, color: Passed through to ``ax.arrow``.
        width: Shaft width relative to the straight-line distance between the
            first and the last point.
        shrink: Fraction of each segment covered by its arrow (default 0.85,
            leaving a small gap before the next point).

    Paths with fewer than two points draw nothing.
    """
    n_points = len(lon)
    if n_points < 2:
        return
    overall = np.hypot(lon[0] - lon[-1], lat[0] - lat[-1])
    for k in range(n_points - 1):
        dx = lon[k + 1] - lon[k]
        dy = lat[k + 1] - lat[k]
        # Head size scales with the individual segment length.
        seg_len = np.hypot(dx, dy)
        ax.arrow(
            lon[k], lat[k], shrink * dx, shrink * dy,
            color=color, alpha=alpha, width=width * overall,
            head_width=0.1 * seg_len, head_length=0.1 * seg_len, overhang=0.2)
        

def plot_path_comparisons(probs, samples, sn_original, en_original, edges_seq_full, df_nodes_in_cluster, fw=True):
    """Plot predicted vs. true paths for several samples on a 3-column grid.

    Each panel shows:
      * the true path as red points and red arrows, annotated with node ids
        taken from ``edges_seq_full[sample]``;
      * the predicted path as default-coloured points and black arrows.
    Axis ticks are hidden and ``plt.show()`` is called at the end.

    Args:
        probs: Per-sample model output.  With ``fw=True`` it is decoded through
            ``get_lat_lon_comparison``.  With ``fw=False``, ``probs[sample]`` is
            used directly as the predicted edge list
            (``get_lat_lon_comparison_2``).
        samples: Sequence of sample ids, plotted in panel order.
        sn_original, en_original: Start / end node per sample id.
        edges_seq_full: True edge list per sample id.  It is indexed as
            ``[k, 0]`` and ``[-1, 1]``, so presumably a 2-D ``(n_edges, 2)``
            array (TODO confirm).
        df_nodes_in_cluster: Node table with ``node_sorted``, ``node_lat`` and
            ``node_lon`` columns.
        fw: Selects between the two decoding routines above.
    """
    n_samples = len(samples)
    if n_samples == 0:
        # plt.subplots(0, 3) would raise; there is nothing to draw.
        return
    n_cols = 3
    n_rows = int(np.ceil(n_samples / n_cols))
    compare = get_lat_lon_comparison if fw else get_lat_lon_comparison_2

    # squeeze=False keeps `axs` 2-D even for a single row.  Without it, a
    # single-row grid (three or fewer samples) made axs[i, j] raise IndexError.
    _, axs = plt.subplots(n_rows, n_cols, figsize=(10, 8), dpi=100, squeeze=False)
    for panel, ax in enumerate(axs.flat):
        if panel >= n_samples:
            ax.axis('off')  # unused trailing panel of the last row
            continue

        sample = samples[panel]
        sn = sn_original[sample]
        en = en_original[sample]
        lon_pred, lat_pred, lon_true, lat_true = compare(
            probs, sample, sn, en, edges_seq_full, df_nodes_in_cluster)

        ax.scatter(lon_true, lat_true, color='r', alpha=0.7)
        ax.scatter(lon_pred, lat_pred)
        plot_arrow(ax, lon_true, lat_true, alpha=0.3, color='r', width=0.02)
        plot_arrow(ax, lon_pred, lat_pred, alpha=0.6, color='black', width=0.01)

        # Label the k-th true-path point with the source of the k-th edge and
        # the last point with the target of the final edge.
        node_seq = edges_seq_full[sample]
        for k in range(len(lon_true) - 1):
            ax.annotate(node_seq[k, 0], (lon_true[k], lat_true[k]),
                        textcoords="offset points", xytext=(5, 5), ha='center')
        ax.annotate(node_seq[-1, 1], (lon_true[-1], lat_true[-1]),
                    textcoords="offset points", xytext=(5, 5), ha='center')

        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()